Density Ratios¶
In [1]:
import numpy as np
import seaborn as sns
#Generate Random Samples from a Mixture of two normal distributions
# Define the parameters for the two normal distributions
mean1 = 0
variance1 = 1
mean2 = 5
variance2 = 2
# Generate random draws from the mixture distribution
n = 1000
weights = [0.95, 0.05] # Equal weights for the two distributions
np.random.seed(42)
# Generate random indices to select which distribution to sample from
indices = np.random.choice([0, 1], size=n, p=weights)
# Generate random samples from the mixture distribution
p = np.zeros(n)
for i in range(n):
if indices[i] == 0:
p[i] = np.random.normal(mean1, np.sqrt(variance1))
else:
p[i] = np.random.normal(mean2, np.sqrt(variance2))
import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
# Plot histogram of samples
axs[0].hist(p, bins=10, color='skyblue', edgecolor='black')
axs[0].set_title('Histogram of Samples')
axs[0].set_xlabel('Value')
axs[0].set_ylabel('Frequency')
# Plot rug plot of samples
sns.rugplot(p, ax=axs[1], height=0.2, color='skyblue')
axs[1].set_title('Rug Plot of Samples')
axs[1].set_xlabel('Value')
axs[1].set_ylabel('')
plt.tight_layout()
plt.show()
In [2]:
import numpy as np
def generate_mixture_samples(means, variances, weights, n):
# Generate random indices to select which distribution to sample from
indices = np.random.choice(len(means), size=n, p=weights)
# Generate random samples from the mixture distribution
samples = np.zeros(n)
for i in range(n):
samples[i] = np.random.normal(means[indices[i]], np.sqrt(variances[indices[i]]))
return samples
In [4]:
import numpy as np
from scipy.stats import norm
def compute_density_ratio(candidate_means, candidate_variances, candidate_weights, true_means, true_variances, true_weights):
# Define the grid of values
grid = np.linspace(-5, 5, 1000)
# Compute the density ratio for each value in the grid
density_ratio = np.zeros_like(grid)
for i, value in enumerate(grid):
candidate_pdf = np.sum([weight * norm.pdf(value, mean, np.sqrt(variance)) for mean, variance, weight in zip(candidate_means, candidate_variances, candidate_weights)])
true_pdf = np.sum([weight * norm.pdf(value, mean, np.sqrt(variance)) for mean, variance, weight in zip(true_means, true_variances, true_weights)])
density_ratio[i] = candidate_pdf / true_pdf
return density_ratio
In [5]:
import seaborn as sns
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
q = generate_mixture_samples([1],[1],[1],1000)
dr = compute_density_ratio([1],[1],[1],[0,5],[1,2],[.95,.05])
# Concatenate p and q into a single vector x
x = np.concatenate((p, q))
# Generate vector y
y = np.concatenate((np.ones_like(p), np.zeros_like(q)))
# Create an instance of LogisticRegression with no penalty
logreg = LogisticRegression(penalty='none')
# Fit the logistic regression model
logreg.fit(x.reshape(-1, 1), y)
# Predict the probabilities for the samples
y_pred = logreg.predict_proba(x.reshape(-1, 1))
# Compute the binary cross-entropy loss
loss = -log_loss(y, y_pred)
import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 3, figsize=(18, 10))
# Plot histogram of p and q samples colored by source
axs[0].hist(p, bins=10, color='skyblue', edgecolor='black', alpha=0.5, label='p')
axs[0].hist(q, bins=10, color='orange', edgecolor='black', alpha=0.5, label='q')
axs[0].set_title('Histogram of Samples')
axs[0].set_xlabel('Value')
axs[0].set_ylabel('Frequency')
axs[0].legend()
# Plot rug plot of p and q samples colored by source
sns.rugplot(p, ax=axs[1], height=0.2, color='skyblue', alpha=0.5, label='Real')
sns.rugplot(q, ax=axs[1], height=0.2, color='orange', alpha=0.5, label='Fake')
axs[1].set_title('Rug Plot of Samples')
axs[1].set_xlabel('Value')
axs[1].set_ylabel('')
axs[1].legend()
# Plot logistic regression curve
x_range = np.linspace(min(x), max(x), 100)
y_range = logreg.predict_proba(x_range.reshape(-1, 1))[:, 1]
axs[1].plot(x_range, y_range, color='red', label='Logistic Regression')
axs[1].legend()
# Print loss in the top left corner
axs[1].text(0.05, 0.95, f'Loss: {loss:.4f}', transform=axs[1].transAxes, verticalalignment='top', bbox=dict(facecolor='white', edgecolor='black', boxstyle='round'))
# Plot dr against np.linspace(-5, 5, 1000)
axs[2].plot(np.linspace(-5, 5, 1000), dr)
axs[2].set_title('Density Ratio')
axs[2].set_xlabel('Value')
axs[2].set_ylabel('Density Ratio')
plt.tight_layout()
plt.show()
/home/kmcalist/.local/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:1183: FutureWarning: `penalty='none'`has been deprecated in 1.2 and will be removed in 1.4. To keep the past behaviour, set `penalty=None`. warnings.warn(
In [6]:
import seaborn as sns
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
q = generate_mixture_samples([0],[1],[1],1000)
dr = compute_density_ratio([0],[1],[1],[0,5],[1,2],[.95,.05])
# Concatenate p and q into a single vector x
x = np.concatenate((p, q))
# Generate vector y
y = np.concatenate((np.ones_like(p), np.zeros_like(q)))
# Create an instance of LogisticRegression with no penalty
logreg = LogisticRegression(penalty='none')
# Fit the logistic regression model
logreg.fit(x.reshape(-1, 1), y)
# Predict the probabilities for the samples
y_pred = logreg.predict_proba(x.reshape(-1, 1))
# Compute the binary cross-entropy loss
loss = -log_loss(y, y_pred)
import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 3, figsize=(18, 10))
# Plot histogram of p and q samples colored by source
axs[0].hist(p, bins=10, color='skyblue', edgecolor='black', alpha=0.5, label='p')
axs[0].hist(q, bins=10, color='orange', edgecolor='black', alpha=0.5, label='q')
axs[0].set_title('Histogram of Samples')
axs[0].set_xlabel('Value')
axs[0].set_ylabel('Frequency')
axs[0].legend()
# Plot rug plot of p and q samples colored by source
sns.rugplot(p, ax=axs[1], height=0.2, color='skyblue', alpha=0.5, label='Real')
sns.rugplot(q, ax=axs[1], height=0.2, color='orange', alpha=0.5, label='Fake')
axs[1].set_title('Rug Plot of Samples')
axs[1].set_xlabel('Value')
axs[1].set_ylabel('')
axs[1].legend()
# Plot logistic regression curve
x_range = np.linspace(min(x), max(x), 100)
y_range = logreg.predict_proba(x_range.reshape(-1, 1))[:, 1]
axs[1].plot(x_range, y_range, color='red', label='Logistic Regression')
axs[1].legend()
# Print loss in the top left corner
axs[1].text(0.05, 0.95, f'Loss: {loss:.4f}', transform=axs[1].transAxes, verticalalignment='top', bbox=dict(facecolor='white', edgecolor='black', boxstyle='round'))
# Plot dr against np.linspace(-5, 5, 1000)
axs[2].plot(np.linspace(-5, 5, 1000), dr)
axs[2].set_title('Density Ratio')
axs[2].set_xlabel('Value')
axs[2].set_ylabel('Density Ratio')
plt.tight_layout()
plt.show()
/home/kmcalist/.local/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:1183: FutureWarning: `penalty='none'`has been deprecated in 1.2 and will be removed in 1.4. To keep the past behaviour, set `penalty=None`. warnings.warn(
In [ ]:
means = np.linspace(-5, 5, 25)
variances = np.linspace(0.001, 3, 25)
import itertools
# Create a cartesian product of means and variances
grid = list(itertools.product(means, variances))
losses = np.zeros((len(means), len(variances)))
for i, (mean, variance) in enumerate(grid):
q = generate_mixture_samples([mean], [variance], [1], 1000)
#dr = compute_density_ratio([mean], [variance], [1], [0, 5], [1, 2], [.95, .05])
x = np.concatenate((p, q))
y = np.concatenate((np.ones_like(p), np.zeros_like(q)))
logreg = LogisticRegression(penalty='none')
logreg.fit(x.reshape(-1, 1), y)
y_pred = logreg.predict_proba(x.reshape(-1, 1))
loss = -log_loss(y, y_pred)
losses[i // len(variances), i % len(variances)] = loss
In [10]:
# Find the point with the minimum loss
min_idx = np.argmin(losses)
min_row, min_col = np.unravel_index(min_idx, losses.shape)
best_mean = means[min_row]
best_variance = variances[min_col]
# Create a meshgrid for contour plotting
X, Y = np.meshgrid(variances, means)
plt.figure(figsize=(8, 6))
contour = plt.contourf(X, Y, losses, levels=20, cmap='viridis')
plt.colorbar(contour)
plt.plot(best_variance, best_mean, 'bX', markersize=15, markeredgewidth=3, label='Min Loss')
plt.xlabel('Variance')
plt.ylabel('Mean')
plt.title('Contour Plot of Losses')
plt.legend()
plt.show()
In [11]:
import seaborn as sns
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
q = generate_mixture_samples([0],[2.5],[1],1000)
dr = compute_density_ratio([0],[1],[2.5],[0,5],[1,2],[.95,.05])
# Concatenate p and q into a single vector x
x = np.concatenate((p, q))
# Generate vector y
y = np.concatenate((np.ones_like(p), np.zeros_like(q)))
# Create an instance of LogisticRegression with no penalty
logreg = LogisticRegression(penalty='none')
# Fit the logistic regression model
logreg.fit(x.reshape(-1, 1), y)
# Predict the probabilities for the samples
y_pred = logreg.predict_proba(x.reshape(-1, 1))
# Compute the binary cross-entropy loss
loss = -log_loss(y, y_pred)
import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 3, figsize=(18, 10))
# Plot histogram of p and q samples colored by source
axs[0].hist(p, bins=10, color='skyblue', edgecolor='black', alpha=0.5, label='p')
axs[0].hist(q, bins=10, color='orange', edgecolor='black', alpha=0.5, label='q')
axs[0].set_title('Histogram of Samples')
axs[0].set_xlabel('Value')
axs[0].set_ylabel('Frequency')
axs[0].legend()
# Plot rug plot of p and q samples colored by source
sns.rugplot(p, ax=axs[1], height=0.2, color='skyblue', alpha=0.5, label='Real')
sns.rugplot(q, ax=axs[1], height=0.2, color='orange', alpha=0.5, label='Fake')
axs[1].set_title('Rug Plot of Samples')
axs[1].set_xlabel('Value')
axs[1].set_ylabel('')
axs[1].legend()
# Plot logistic regression curve
x_range = np.linspace(min(x), max(x), 100)
y_range = logreg.predict_proba(x_range.reshape(-1, 1))[:, 1]
axs[1].plot(x_range, y_range, color='red', label='Logistic Regression')
axs[1].legend()
# Print loss in the top left corner
axs[1].text(0.05, 0.95, f'Loss: {loss:.4f}', transform=axs[1].transAxes, verticalalignment='top', bbox=dict(facecolor='white', edgecolor='black', boxstyle='round'))
# Plot dr against np.linspace(-5, 5, 1000)
axs[2].plot(np.linspace(-5, 5, 1000), dr)
axs[2].set_title('Density Ratio')
axs[2].set_xlabel('Value')
axs[2].set_ylabel('Density Ratio')
plt.tight_layout()
plt.show()
/home/kmcalist/.local/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:1183: FutureWarning: `penalty='none'`has been deprecated in 1.2 and will be removed in 1.4. To keep the past behaviour, set `penalty=None`. warnings.warn(
In [12]:
import seaborn as sns
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
q = generate_mixture_samples([0,5],[1,2],[.95,.05],1000)
dr = compute_density_ratio([0,5],[1,2],[.95,.05],[0,5],[1,2],[.95,.05])
# Concatenate p and q into a single vector x
x = np.concatenate((p, q))
# Generate vector y
y = np.concatenate((np.ones_like(p), np.zeros_like(q)))
# Create an instance of LogisticRegression with no penalty
logreg = LogisticRegression(penalty='none')
# Fit the logistic regression model
logreg.fit(x.reshape(-1, 1), y)
# Predict the probabilities for the samples
y_pred = logreg.predict_proba(x.reshape(-1, 1))
# Compute the binary cross-entropy loss
loss = -log_loss(y, y_pred)
import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 3, figsize=(18, 10))
# Plot histogram of p and q samples colored by source
axs[0].hist(p, bins=10, color='skyblue', edgecolor='black', alpha=0.5, label='p')
axs[0].hist(q, bins=10, color='orange', edgecolor='black', alpha=0.5, label='q')
axs[0].set_title('Histogram of Samples')
axs[0].set_xlabel('Value')
axs[0].set_ylabel('Frequency')
axs[0].legend()
# Plot rug plot of p and q samples colored by source
sns.rugplot(p, ax=axs[1], height=0.2, color='skyblue', alpha=0.5, label='Real')
sns.rugplot(q, ax=axs[1], height=0.2, color='orange', alpha=0.5, label='Fake')
axs[1].set_title('Rug Plot of Samples')
axs[1].set_xlabel('Value')
axs[1].set_ylabel('')
axs[1].legend()
# Plot logistic regression curve
x_range = np.linspace(min(x), max(x), 100)
y_range = logreg.predict_proba(x_range.reshape(-1, 1))[:, 1]
axs[1].plot(x_range, y_range, color='red', label='Logistic Regression')
axs[1].legend()
# Print loss in the top left corner
axs[1].text(0.05, 0.95, f'Loss: {loss:.4f}', transform=axs[1].transAxes, verticalalignment='top', bbox=dict(facecolor='white', edgecolor='black', boxstyle='round'))
# Plot dr against np.linspace(-5, 5, 1000)
axs[2].plot(np.linspace(-5, 5, 1000), dr)
axs[2].set_title('Density Ratio')
axs[2].set_xlabel('Value')
axs[2].set_ylabel('Density Ratio')
plt.tight_layout()
plt.show()
/home/kmcalist/.local/lib/python3.10/site-packages/sklearn/linear_model/_logistic.py:1183: FutureWarning: `penalty='none'`has been deprecated in 1.2 and will be removed in 1.4. To keep the past behaviour, set `penalty=None`. warnings.warn(
Simple f-GAN¶
In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
def sample_real_data(batch_size):
"""
Generate samples from a mixture of 9 (3x3 grid) 2D Gaussians.
Each Gaussian component has a center that is uniformly chosen from:
(-1,-1), (-1,0), (-1,1),
(0,-1), (0,0), (0,1),
(1,-1), (1,0), (1,1).
Each component has a standard deviation.
"""
centers = np.array([[-1, -1], [-1, 0], [-1, 1],
[ 0, -1], [ 0, 0], [ 0, 1],
[ 1, -1], [ 1, 0], [ 1, 1]])
num_components = centers.shape[0]
indices = np.random.choice(num_components, size=batch_size)
chosen_centers = centers[indices]
std = 0.1
samples = chosen_centers + np.random.randn(batch_size, 2) * std
return torch.tensor(samples, dtype=torch.float32)
In [2]:
# Sample 10,000 points from the real data distribution.
samples = sample_real_data(10000)
x = samples[:, 0]
y = samples[:, 1]
# Set up the plot.
plt.figure(figsize=(8, 8))
# Use seaborn's kdeplot to create a density plot.
sns.kdeplot(x=x, y=y, cmap="viridis", fill=True, thresh=0, levels=100)
plt.title("Density of Real Data (10,000 Samples)")
plt.xlabel("x")
plt.ylabel("y")
plt.show()
In [3]:
# -----------------------
# 1. Define a generic MLP (like the JAX MLP class)
# -----------------------
class MLP(nn.Module):
def __init__(self, input_dim, features):
"""
Constructs an MLP that applies a linear layer followed by ReLU for all but the final layer.
Args:
input_dim: Dimension of the input.
features: List of integers, where each element corresponds to the output size of a dense layer.
The final element does not get a nonlinearity.
"""
super(MLP, self).__init__()
layers = []
in_dim = input_dim
for i, out_dim in enumerate(features):
layers.append(nn.Linear(in_dim, out_dim))
if i < len(features) - 1:
layers.append(nn.ReLU(inplace=True))
in_dim = out_dim
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x)
In [4]:
# -----------------------
# 3. Define loss functions using log-sigmoid (replicates the JAX loss computation)
# -----------------------
def log_sigmoid(x):
# Returns log(sigmoid(x))
return torch.log(torch.sigmoid(x))
def discriminator_loss(D, G, real_examples, latents):
"""
Computes the discriminator loss as:
loss = mean( -log_sigmoid(D(real)) - log_sigmoid(-D(G(latents))) )
"""
real_logits = D(real_examples)
fake_examples = G(latents)
fake_logits = D(fake_examples)
loss_real = - log_sigmoid(real_logits)
loss_fake = - log_sigmoid(-fake_logits)
return torch.mean(loss_real + loss_fake)
def generator_loss(D, G, latents):
"""
Computes the generator loss as:
loss = mean( - log_sigmoid(D(G(latents))) )
"""
fake_examples = G(latents)
fake_logits = D(fake_examples)
loss = - log_sigmoid(fake_logits)
return torch.mean(loss)
In [5]:
# -----------------------
# 4. Training loop replicating the JAX training steps with SGD
# -----------------------
def train_gan(num_iters=20001, batch_size=512, latent_size=32, lr=0.05, n_save=2000, draw_contours=False, device='cpu'):
device = torch.device(device)
# Create the discriminator and generator.
# Discriminator: input_dim 2, hidden layers: 25, 25, output_dim: 1
# Generator: input_dim latent_size, hidden layers: 25, 25, output_dim: 2
D = MLP(input_dim=2, features=[128, 128, 128, 1]).to(device)
G = MLP(input_dim=latent_size, features=[128,128,128,2]).to(device)
# Set up SGD optimizers replicating the JAX SGD with lr=0.05.
optimizer_D = optim.SGD(D.parameters(), lr=lr)
optimizer_G = optim.SGD(G.parameters(), lr=lr)
# Prepare a fixed test latent vector for evaluation (10,000 samples)
test_latents = torch.randn(10000, latent_size, device=device)
history = [] # will store tuples: (iteration, fake_examples, disc_contour, disc_loss, gen_loss)
for i in range(num_iters):
# Sample minibatch of real examples (shape: [batch_size, 2])
real_examples = sample_real_data(batch_size).to(device)
# Sample minibatch of latent vectors from a standard normal (shape: [batch_size, latent_size])
latents = torch.randn(batch_size, latent_size, device=device)
# -- Discriminator step --
optimizer_D.zero_grad()
loss_D = discriminator_loss(D, G, real_examples, latents)
loss_D.backward()
optimizer_D.step()
# -- Generator step --
optimizer_G.zero_grad()
# We use the same minibatch of latents here.
loss_G = generator_loss(D, G, latents)
loss_G.backward()
optimizer_G.step()
if i % n_save == 0:
print(f"i = {i}, Discriminator Loss = {loss_D.item()}, Generator Loss = {loss_G.item()}")
with torch.no_grad():
fake_examples = G(test_latents)
disc_contour = None
if draw_contours:
# Optional: compute a contour measure over some grid if desired.
# (The original code computes: -D(pairs) + log_sigmoid(D(pairs)))
# For simplicity, we leave this as None.
disc_contour = None
history.append((i, fake_examples.cpu(), disc_contour, loss_D.item(), loss_G.item()))
return D, G, history
In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
D, G, history = train_gan(num_iters=20001, batch_size=512, latent_size=32, lr=0.05,
n_save=2000, draw_contours=False, device=device)
i = 0, Discriminator Loss = 1.384421944618225, Generator Loss = 0.654618501663208 i = 2000, Discriminator Loss = 1.2403438091278076, Generator Loss = 0.6935384273529053 i = 4000, Discriminator Loss = 1.0644993782043457, Generator Loss = 1.0018454790115356 i = 6000, Discriminator Loss = 1.0645021200180054, Generator Loss = 1.893606424331665 i = 8000, Discriminator Loss = 0.7953554391860962, Generator Loss = 1.5759837627410889 i = 10000, Discriminator Loss = 0.8092895746231079, Generator Loss = 1.2489733695983887 i = 12000, Discriminator Loss = 0.8678987622261047, Generator Loss = 1.3148552179336548 i = 14000, Discriminator Loss = 0.9495569467544556, Generator Loss = 1.1578044891357422 i = 16000, Discriminator Loss = 0.8955338001251221, Generator Loss = 1.198028564453125 i = 18000, Discriminator Loss = 0.8887498378753662, Generator Loss = 1.2266111373901367 i = 20000, Discriminator Loss = 0.9115414619445801, Generator Loss = 1.2572792768478394
In [7]:
import matplotlib.pyplot as plt
import seaborn as sns
# Assuming 'history' is available from training
# Each element in history is a tuple: (iteration, fake_samples, disc_loss, gen_loss)
for entry in history:
iteration, fake_samples, disc_contour, disc_loss, gen_loss = entry
# Create a figure for each snapshot
plt.figure(figsize=(6, 6))
# Use Seaborn's kdeplot to compute and display the 2D kernel density estimate.
sns.kdeplot(x=fake_samples[:, 0], y=fake_samples[:, 1],
fill=True, levels=50, cmap="viridis")
# Add labels and a title with iteration and losses information.
plt.xlabel("x")
plt.ylabel("y")
plt.title(f"Estimated Density at Iteration {iteration}\n"
f"Discriminator Loss: {disc_loss:.4f}, Generator Loss: {gen_loss:.4f}")
plt.tight_layout()
plt.show()
Simple w-GAN¶
In [20]:
class MLP(nn.Module):
def __init__(self, input_dim, features):
"""
Constructs an MLP with hidden layers specified by the list `features`.
A ReLU activation is applied after each layer except the final one.
"""
super(MLP, self).__init__()
layers = []
in_dim = input_dim
for i, out_dim in enumerate(features):
layers.append(nn.Linear(in_dim, out_dim))
if i < len(features) - 1:
layers.append(nn.ReLU(inplace=True))
in_dim = out_dim
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x)
# -----------------------
# 3. Define Wasserstein Losses
# -----------------------
def critic_loss(D, G, real_examples, latents):
"""
Computes the Wasserstein critic loss:
L_D = -(E[D(real)] - E[D(fake)])
"""
real_scores = D(real_examples)
fake_scores = D(G(latents))
return fake_scores.mean() - real_scores.mean()
def generator_loss(D, G, latents):
"""
Computes the generator loss for WGAN:
L_G = -E[D(fake)]
"""
return - D(G(latents)).mean()
In [21]:
# -----------------------
# 4. Training Loop for WGAN with n_disc updates per iteration
# -----------------------
def train_wgan(num_iters=20001, batch_size=512, latent_size=32, lr=0.05,
n_save=2000, n_disc=5, clip_value=0.01, device='cpu', draw_contours = False):
device = torch.device(device)
# Instantiate the critic and generator.
D = MLP(input_dim=2, features=[128,128,128, 1]).to(device) # Critic: 2D -> score (no sigmoid!)
G = MLP(input_dim=latent_size, features=[128,128,128, 2]).to(device) # Generator: latent -> 2D output
# Set up optimizers using SGD as in the original JAX code.
optimizer_D = optim.SGD(D.parameters(), lr=lr)
optimizer_G = optim.SGD(G.parameters(), lr=lr)
# Fixed test latents for monitoring (10,000 samples)
test_latents = torch.randn(10000, latent_size, device=device)
history = [] # List to store snapshots: (iteration, fake_samples, critic_loss, generator_loss)
for i in range(num_iters):
# --- Critic (Discriminator) update: n_disc iterations ---
for _ in range(n_disc):
real_examples = sample_real_data(batch_size).to(device)
latents = torch.randn(batch_size, latent_size, device=device)
optimizer_D.zero_grad()
loss_D = critic_loss(D, G, real_examples, latents)
loss_D.backward()
optimizer_D.step()
# Weight clipping to enforce Lipschitz condition.
for p in D.parameters():
p.data.clamp_(-clip_value, clip_value)
# --- Generator update (one step after n_disc updates) ---
latents = torch.randn(batch_size, latent_size, device=device)
optimizer_G.zero_grad()
loss_G = generator_loss(D, G, latents)
loss_G.backward()
optimizer_G.step()
if i % n_save == 0:
print(f"i = {i}, Discriminator Loss = {loss_D.item()}, Generator Loss = {loss_G.item()}")
with torch.no_grad():
fake_examples = G(test_latents)
disc_contour = None
if draw_contours:
# Optional: compute a contour measure over some grid if desired.
# (The original code computes: -D(pairs) + log_sigmoid(D(pairs)))
# For simplicity, we leave this as None.
disc_contour = None
history.append((i, fake_examples.cpu(), disc_contour, loss_D.item(), loss_G.item()))
return D, G, history
In [22]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
D, G, history = train_wgan(num_iters=20001, batch_size=10000, latent_size=32, lr=0.05,
n_save=2000, n_disc=2, clip_value=.01, device=device)
i = 0, Discriminator Loss = -4.794273991137743e-07, Generator Loss = 0.00012903052265755832 i = 2000, Discriminator Loss = -1.542569589219056e-05, Generator Loss = 4.844953946303576e-05 i = 4000, Discriminator Loss = -4.3357867980375886e-05, Generator Loss = -3.837565236608498e-05 i = 6000, Discriminator Loss = -8.668069494888186e-05, Generator Loss = -0.0002593372482806444 i = 8000, Discriminator Loss = -0.00016360956942662597, Generator Loss = -0.0002497853129170835 i = 10000, Discriminator Loss = -0.0002925149165093899, Generator Loss = -0.0002160591830033809 i = 12000, Discriminator Loss = -0.0005131656071171165, Generator Loss = -0.0001401335612172261 i = 14000, Discriminator Loss = -0.0008085581357590854, Generator Loss = -7.504215318476781e-05 i = 16000, Discriminator Loss = -0.0011433582985773683, Generator Loss = -9.464036338613369e-06 i = 18000, Discriminator Loss = -0.0014178385026752949, Generator Loss = 2.4077523903542897e-06 i = 20000, Discriminator Loss = -0.0015906720655038953, Generator Loss = -1.5236424587783404e-05
In [23]:
import matplotlib.pyplot as plt
import seaborn as sns
# Assuming 'history' is available from training
# Each element in history is a tuple: (iteration, fake_samples, disc_loss, gen_loss)
for entry in history:
iteration, fake_samples, disc_contour, disc_loss, gen_loss = entry
# Create a figure for each snapshot
plt.figure(figsize=(6, 6))
# Use Seaborn's kdeplot to compute and display the 2D kernel density estimate.
sns.kdeplot(x=fake_samples[:, 0], y=fake_samples[:, 1],
fill=True, levels=50, cmap="viridis")
# Add labels and a title with iteration and losses information.
plt.xlabel("x")
plt.ylabel("y")
plt.title(f"Estimated Density at Iteration {iteration}\n"
f"Discriminator Loss: {disc_loss:.4f}, Generator Loss: {gen_loss:.4f}")
plt.tight_layout()
plt.show()
WGAN-GP¶
In [24]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# -----------------------
# 1. Define a generic MLP (used for both critic and generator)
# -----------------------
class MLP(nn.Module):
def __init__(self, input_dim, features):
"""
Constructs an MLP with hidden layers specified by the list `features`.
A ReLU activation is applied after each layer except the final one.
"""
super(MLP, self).__init__()
layers = []
in_dim = input_dim
for i, out_dim in enumerate(features):
layers.append(nn.Linear(in_dim, out_dim))
if i < len(features) - 1:
layers.append(nn.ReLU(inplace=True))
in_dim = out_dim
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x)
# -----------------------
# 2. Real Data Distribution: 9-component Mixture of 2D Gaussians
# -----------------------
def sample_real_data(batch_size, std=0.1):
"""
Samples a batch of 2D points drawn from one of 9 Gaussian components arranged on a 3x3 grid.
The centers are:
(-1,-1), (-1, 0), (-1, 1),
( 0,-1), ( 0, 0), ( 0, 1),
( 1,-1), ( 1, 0), ( 1, 1)
"""
centers = np.array([[-1, -1], [-1, 0], [-1, 1],
[ 0, -1], [ 0, 0], [ 0, 1],
[ 1, -1], [ 1, 0], [ 1, 1]])
num_components = centers.shape[0]
indices = np.random.choice(num_components, size=batch_size)
chosen_centers = centers[indices]
samples = chosen_centers + np.random.randn(batch_size, 2) * std
return torch.tensor(samples, dtype=torch.float)
# -----------------------
# 3. Define Gradient Penalty and Wasserstein Losses (WGAN-GP)
# -----------------------
def gradient_penalty(D, real_samples, fake_samples, device):
"""Computes the gradient penalty for interpolated samples."""
batch_size = real_samples.size(0)
# Sample interpolation coefficient uniformly between 0 and 1.
alpha = torch.rand(batch_size, 1, device=device)
alpha = alpha.expand_as(real_samples)
# Create interpolated samples.
interpolates = alpha * real_samples + (1 - alpha) * fake_samples
interpolates.requires_grad_(True)
# Compute critic scores on interpolated samples.
d_interpolates = D(interpolates)
# For each sample, create a tensor of ones with the same shape as the output.
ones = torch.ones(d_interpolates.size(), device=device)
# Compute gradients of critic scores with respect to the interpolated samples.
gradients = torch.autograd.grad(outputs=d_interpolates, inputs=interpolates,
grad_outputs=ones,
create_graph=True,
retain_graph=True,
only_inputs=True)[0]
# Flatten gradients per sample.
gradients = gradients.view(batch_size, -1)
gradient_norm = gradients.norm(2, dim=1)
# Compute penalty as (gradient norm - 1)^2.
gp = ((gradient_norm - 1) ** 2).mean()
return gp
def critic_loss(D, G, real_examples, latents, lambda_gp, device):
"""
Computes the WGAN critic loss with gradient penalty:
L_D = E[D(fake)] - E[D(real)] + lambda_gp * GP,
where GP is the gradient penalty.
"""
fake_examples = G(latents)
real_scores = D(real_examples)
fake_scores = D(fake_examples)
loss = fake_scores.mean() - real_scores.mean()
# Compute gradient penalty on interpolated samples.
gp = gradient_penalty(D, real_examples, fake_examples, device)
loss += lambda_gp * gp
return loss
def generator_loss(D, G, latents):
"""
Computes the generator loss for WGAN:
L_G = -E[D(G(latents))]
"""
return - D(G(latents)).mean()
# -----------------------
# 4. Training Loop for WGAN-GP with n_disc updates per iteration
# -----------------------
def train_wgan_gp(num_iters=20001, batch_size=512, latent_size=32, lr=0.05,
n_save=2000, n_disc=5, lambda_gp=10, device='cpu', draw_contours = False):
device = torch.device(device)
# Instantiate the critic and generator.
D = MLP(input_dim=2, features=[128,128,128,1]).to(device) # Critic: 2D -> score (no activation)
G = MLP(input_dim=latent_size, features=[128, 128,128, 2]).to(device) # Generator: latent -> 2D output
# Set up optimizers (using SGD as in the JAX code).
optimizer_D = optim.SGD(D.parameters(), lr=lr)
optimizer_G = optim.SGD(G.parameters(), lr=lr)
# Fixed test latent vectors for monitoring (10,000 samples).
test_latents = torch.randn(10000, latent_size, device=device)
history = [] # To store snapshots: (iteration, fake_samples, critic_loss, generator_loss)
for i in range(num_iters):
# --- Critic (Discriminator) update: perform n_disc updates ---
for _ in range(n_disc):
real_examples = sample_real_data(batch_size).to(device)
latents = torch.randn(batch_size, latent_size, device=device)
optimizer_D.zero_grad()
loss_D = critic_loss(D, G, real_examples, latents, lambda_gp, device)
loss_D.backward()
optimizer_D.step()
# --- Generator update (one update after n_disc critic updates) ---
latents = torch.randn(batch_size, latent_size, device=device)
optimizer_G.zero_grad()
loss_G = generator_loss(D, G, latents)
loss_G.backward()
optimizer_G.step()
if i % n_save == 0:
print(f"i = {i}, Discriminator Loss = {loss_D.item()}, Generator Loss = {loss_G.item()}")
with torch.no_grad():
fake_examples = G(test_latents)
disc_contour = None
if draw_contours:
# Optional: compute a contour measure over some grid if desired.
# (The original code computes: -D(pairs) + log_sigmoid(D(pairs)))
# For simplicity, we leave this as None.
disc_contour = None
history.append((i, fake_examples.cpu(), disc_contour, loss_D.item(), loss_G.item()))
return D, G, history
In [29]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
D, G, history = train_wgan_gp(num_iters=20001, batch_size=10000, latent_size=32, lr=0.05,
n_save=2000, n_disc=2, lambda_gp=1, device=device)
i = 0, Discriminator Loss = 0.9039061069488525, Generator Loss = 0.1036360114812851 i = 2000, Discriminator Loss = -0.1359843909740448, Generator Loss = 5.910412788391113 i = 4000, Discriminator Loss = -0.07140640914440155, Generator Loss = 4.818828582763672 i = 6000, Discriminator Loss = -0.02494090050458908, Generator Loss = 4.749511241912842 i = 8000, Discriminator Loss = -0.004812396131455898, Generator Loss = 4.237672328948975 i = 10000, Discriminator Loss = -0.0016526570543646812, Generator Loss = 3.8261005878448486 i = 12000, Discriminator Loss = -0.002830632496625185, Generator Loss = 3.5270280838012695 i = 14000, Discriminator Loss = 0.014997678808867931, Generator Loss = 3.673243761062622 i = 16000, Discriminator Loss = 0.0187930129468441, Generator Loss = 3.8711085319519043 i = 18000, Discriminator Loss = 0.023033861070871353, Generator Loss = 4.0552520751953125 i = 20000, Discriminator Loss = 0.026469510048627853, Generator Loss = 4.31480073928833
In [30]:
import matplotlib.pyplot as plt
import seaborn as sns
# Assuming 'history' is available from training
# Each element in history is a tuple: (iteration, fake_samples, disc_loss, gen_loss)
for entry in history:
iteration, fake_samples, disc_contour, disc_loss, gen_loss = entry
# Create a figure for each snapshot
plt.figure(figsize=(6, 6))
# Use Seaborn's kdeplot to compute and display the 2D kernel density estimate.
sns.kdeplot(x=fake_samples[:, 0], y=fake_samples[:, 1],
fill=True, levels=50, cmap="viridis")
# Add labels and a title with iteration and losses information.
plt.xlabel("x")
plt.ylabel("y")
plt.title(f"Estimated Density at Iteration {iteration}\n"
f"Discriminator Loss: {disc_loss:.4f}, Generator Loss: {gen_loss:.4f}")
plt.tight_layout()
plt.show()
f-GAN for Celebrity Faces (DCGAN)¶
In [1]:
import os
import pandas as pd
# Specify the folder containing your subset of images and the CSV file path
images_folder = 'subset_images'
csv_path = 'list_attr_celeba.csv'
# Read the CSV file into a DataFrame
df = pd.read_csv(csv_path)
# Get a set of image filenames present in the subset_images folder
existing_images = set(os.listdir(images_folder))
# Filter the DataFrame to only include rows where the image_id exists in the folder
df_subset = df[df['image_id'].isin(existing_images)]
df_subset = df_subset.reset_index(drop=True)
print("Total rows in original CSV:", len(df))
print("Total rows in subset CSV:", len(df_subset))
# Optionally, save the subset to a new CSV file for future use:
df_subset.to_csv("subset_list_attr_celeba.csv", index=False)
# Get attribute names from the CSV (all columns except the first "image_id")
attribute_names = list(df_subset.columns[1:])
print("Attribute Names:")
print(attribute_names)
Total rows in original CSV: 202599 Total rows in subset CSV: 15000 Attribute Names: ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin', 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young']
In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
class CelebASubsetDataset(Dataset):
def __init__(self, images_dir, csv_file, transform=None):
"""
Args:
images_dir (str): Path to the directory containing the images.
csv_file (str): Path to the CSV file with image attributes.
transform (callable, optional): Optional transform to be applied on an image.
"""
self.images_dir = images_dir
self.transform = transform
# Read the CSV file into a DataFrame
self.attr_df = pd.read_csv(csv_file)
# Assuming the first column is 'image_id' and the rest are attributes, we store the attribute names.
self.image_ids = self.attr_df['image_id'].values
# Get attribute columns (all columns besides 'image_id')
self.attributes = self.attr_df.drop(columns=['image_id']).values.astype('float32')
def __len__(self):
return len(self.image_ids)
def __getitem__(self, idx):
# Get the image file name and its corresponding attributes
img_id = self.image_ids[idx]
attr = self.attributes[idx]
# Construct the full image path
img_path = os.path.join(self.images_dir, img_id)
# Open the image file and ensure it is in RGB mode.
image = Image.open(img_path).convert('RGB')
if self.transform:
image = self.transform(image)
# Convert the attributes to a tensor
attr_tensor = torch.tensor(attr)
return image, attr_tensor
# Update transform: Resize to 64x64, then ToTensor, and normalize to [-1, 1]
transform = transforms.Compose([
transforms.Resize((64, 64)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# Create the full dataset instance.
dataset = CelebASubsetDataset(
images_dir='subset_images',
csv_file='subset_list_attr_celeba.csv',
transform=transform
)
# Define a split proportion for training and validation.
# For example, an 80/20 split:
train_size = int(0.999 * len(dataset))
valid_size = len(dataset) - train_size
# Split the dataset using random_split.
train_dataset, valid_dataset = random_split(dataset, [train_size, valid_size])
# Create DataLoaders for both splits.
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
valid_dataloader = DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=4)
# Optional: Print out the sizes for confirmation.
print(f"Total dataset size: {len(dataset)}")
print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(valid_dataset)}")
Total dataset size: 15000 Training set size: 14985 Validation set size: 15
In [3]:
# =======================================================
# Define a helper module for reshaping in the Generator
# =======================================================
class View(nn.Module):
def __init__(self, shape):
"""
A simple layer to reshape tensors to the given shape.
"""
super(View, self).__init__()
self.shape = shape
def forward(self, input):
return input.view(*self.shape)
In [4]:
# =======================================================
# Generator: DCGAN-Style for 64x64 Images
# =======================================================
class Generator(nn.Module):
def __init__(self, nz=100, ngf=64):
"""
Generator mapping a latent vector (dimension nz) to a 3x64x64 image.
It first projects the latent vector to a tensor of shape (ngf*8, 8, 8),
then upsamples through three layers:
8x8 -> 16x16,
16x16 -> 32x32,
32x32 -> 64x64.
The final convolution produces a 3-channel image.
"""
super(Generator, self).__init__()
self.nz = nz
self.main = nn.Sequential(
# Project latent vector and reshape:
nn.Linear(nz, ngf * 8 * 8 * 8),
nn.BatchNorm1d(ngf * 8 * 8 * 8),
nn.ReLU(True),
View((-1, ngf * 8, 8, 8)), # Shape: (ngf*8, 8, 8)
# Upsample: 8x8 -> 16x16
nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),
# Upsample: 16x16 -> 32x32
nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
# Upsample: 32x32 -> 64x64
nn.ConvTranspose2d(ngf * 2, 3, kernel_size=4, stride=2, padding=1, bias=False),
nn.Tanh() # Output values in the range [-1,1]
)
def forward(self, input):
return self.main(input)
In [5]:
# =======================================================
# Discriminator: DCGAN-Style for 64x64 Images
# =======================================================
class Discriminator(nn.Module):
def __init__(self, ndf=64):
"""
The discriminator takes a 3x64x64 image and outputs a scalar probability.
It uses four convolutional layers:
- 64x64 -> 32x32,
- 32x32 -> 16x16,
- 16x16 -> 8x8,
- 8x8 -> 4x4.
A final convolution with kernel size 4 collapses the 4x4 feature map to a single value.
"""
super(Discriminator, self).__init__()
self.main = nn.Sequential(
# Input: 3 x 64 x 64 -> ndf x 32 x 32
nn.Conv2d(3, ndf, kernel_size=4, stride=2, padding=1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# 32x32 -> 16x16; output channels: ndf*2
nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.2, inplace=True),
# 16x16 -> 8x8; output channels: ndf*4
nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True),
# 8x8 -> 4x4; output channels: ndf*8
nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, inplace=True),
# Final layer: 4x4 -> 1, using kernel size 4
nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),
nn.Sigmoid() # For probability output
)
def forward(self, input):
out = self.main(input)
return out.view(-1)
In [6]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from PIL import Image
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt
# =======================================================
# Training Procedure using BCE Loss
# =======================================================
def train_f_gan(num_epochs=50, nz=100, device='cpu'):
device = torch.device(device)
# Instantiate generator and discriminator.
G = Generator(nz=nz, ngf=64).to(device)
D = Discriminator(ndf=64).to(device)
# Use Adam optimizers with typical parameters for DCGAN.
optimizerD = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
# Use binary cross entropy loss.
criterion = nn.BCELoss()
# Create a fixed set of latent vectors (25 images) for visualization.
fixed_noise = torch.randn(25, nz, device=device)
# Variables to store last batch losses for display.
last_loss_D = None
last_loss_G = None
for epoch in range(num_epochs):
for i, (images, _) in enumerate(train_dataloader):
images = images.to(device) # [B, 3, 224, 224]
batch_size = images.size(0)
# Create labels.
real_labels = torch.ones(batch_size, device=device)
fake_labels = torch.zeros(batch_size, device=device)
# -------------------------
# Train Discriminator
# -------------------------
optimizerD.zero_grad()
# Real images forward-pass.
outputs_real = D(images)
loss_real = criterion(outputs_real, real_labels)
# Generate fake images.
noise = torch.randn(batch_size, nz, device=device)
fake_images = G(noise)
outputs_fake = D(fake_images.detach())
loss_fake = criterion(outputs_fake, fake_labels)
loss_D = loss_real + loss_fake
loss_D.backward()
optimizerD.step()
# -------------------------
# Train Generator
# -------------------------
optimizerG.zero_grad()
# Generate fake images; we want these to be classified as real.
noise = torch.randn(batch_size, nz, device=device)
fake_images = G(noise)
outputs = D(fake_images)
loss_G = criterion(outputs, real_labels)
loss_G.backward()
optimizerG.step()
if i % 50 == 0:
print(f"Epoch [{epoch}/{num_epochs}], Batch [{i}/{len(train_dataloader)}], "
f"Loss_D: {loss_D.item():.4f}, Loss_G: {loss_G.item():.4f}")
last_loss_D = loss_D.item()
last_loss_G = loss_G.item()
# End of each epoch: generate 25 images from fixed latent vectors and display them.
with torch.no_grad():
fake_samples = G(fixed_noise).detach().cpu()
# Create a grid of 25 images.
grid = utils.make_grid(fake_samples, nrow=5, normalize=True)
plt.figure(figsize=(8, 8))
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.axis('off')
plt.title(f"Epoch {epoch} | Loss_D: {last_loss_D:.4f} | Loss_G: {last_loss_G:.4f}")
plt.show()
plt.close()
print(f"End of Epoch {epoch} completed.")
return D, G
In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
D, G = train_f_gan(num_epochs=20, nz=100, device=device)
Epoch [0/20], Batch [0/235], Loss_D: 1.4343, Loss_G: 1.4272 Epoch [0/20], Batch [50/235], Loss_D: 0.0313, Loss_G: 8.7729 Epoch [0/20], Batch [100/235], Loss_D: 0.0303, Loss_G: 14.2098 Epoch [0/20], Batch [150/235], Loss_D: 0.2522, Loss_G: 7.2179 Epoch [0/20], Batch [200/235], Loss_D: 0.4485, Loss_G: 6.5541
End of Epoch 0 completed. Epoch [1/20], Batch [0/235], Loss_D: 0.2547, Loss_G: 6.2205 Epoch [1/20], Batch [50/235], Loss_D: 0.1791, Loss_G: 3.9664 Epoch [1/20], Batch [100/235], Loss_D: 0.0452, Loss_G: 5.0688 Epoch [1/20], Batch [150/235], Loss_D: 0.3047, Loss_G: 6.7193 Epoch [1/20], Batch [200/235], Loss_D: 0.0989, Loss_G: 4.0727
End of Epoch 1 completed. Epoch [2/20], Batch [0/235], Loss_D: 0.0806, Loss_G: 4.5329 Epoch [2/20], Batch [50/235], Loss_D: 0.1334, Loss_G: 5.1288 Epoch [2/20], Batch [100/235], Loss_D: 0.1641, Loss_G: 3.6295 Epoch [2/20], Batch [150/235], Loss_D: 1.7095, Loss_G: 1.4473 Epoch [2/20], Batch [200/235], Loss_D: 0.1942, Loss_G: 4.6550
End of Epoch 2 completed. Epoch [3/20], Batch [0/235], Loss_D: 0.0985, Loss_G: 3.7031 Epoch [3/20], Batch [50/235], Loss_D: 0.1984, Loss_G: 3.6030 Epoch [3/20], Batch [100/235], Loss_D: 0.4367, Loss_G: 10.5271 Epoch [3/20], Batch [150/235], Loss_D: 0.1521, Loss_G: 2.9989 Epoch [3/20], Batch [200/235], Loss_D: 0.1850, Loss_G: 3.4543
End of Epoch 3 completed. Epoch [4/20], Batch [0/235], Loss_D: 0.4477, Loss_G: 1.8774 Epoch [4/20], Batch [50/235], Loss_D: 0.2037, Loss_G: 4.1418 Epoch [4/20], Batch [100/235], Loss_D: 0.0637, Loss_G: 3.5348 Epoch [4/20], Batch [150/235], Loss_D: 0.6581, Loss_G: 1.7092 Epoch [4/20], Batch [200/235], Loss_D: 0.1727, Loss_G: 2.8104
End of Epoch 4 completed. Epoch [5/20], Batch [0/235], Loss_D: 0.1548, Loss_G: 3.2654 Epoch [5/20], Batch [50/235], Loss_D: 0.7366, Loss_G: 5.2767 Epoch [5/20], Batch [100/235], Loss_D: 0.0485, Loss_G: 4.6805 Epoch [5/20], Batch [150/235], Loss_D: 0.2277, Loss_G: 4.8771 Epoch [5/20], Batch [200/235], Loss_D: 0.0831, Loss_G: 3.9929
End of Epoch 5 completed. Epoch [6/20], Batch [0/235], Loss_D: 0.0328, Loss_G: 4.9790 Epoch [6/20], Batch [50/235], Loss_D: 0.1184, Loss_G: 3.6737 Epoch [6/20], Batch [100/235], Loss_D: 0.0574, Loss_G: 4.9615 Epoch [6/20], Batch [150/235], Loss_D: 0.4830, Loss_G: 1.6109 Epoch [6/20], Batch [200/235], Loss_D: 1.1322, Loss_G: 11.0200
End of Epoch 6 completed. Epoch [7/20], Batch [0/235], Loss_D: 3.3290, Loss_G: 7.9481 Epoch [7/20], Batch [50/235], Loss_D: 0.0750, Loss_G: 3.3906 Epoch [7/20], Batch [100/235], Loss_D: 0.0546, Loss_G: 4.9041 Epoch [7/20], Batch [150/235], Loss_D: 0.0656, Loss_G: 5.2365 Epoch [7/20], Batch [200/235], Loss_D: 0.2962, Loss_G: 4.2396
End of Epoch 7 completed. Epoch [8/20], Batch [0/235], Loss_D: 0.1624, Loss_G: 3.8747 Epoch [8/20], Batch [50/235], Loss_D: 0.1166, Loss_G: 2.9114 Epoch [8/20], Batch [100/235], Loss_D: 6.3864, Loss_G: 7.3692 Epoch [8/20], Batch [150/235], Loss_D: 0.2464, Loss_G: 3.1736 Epoch [8/20], Batch [200/235], Loss_D: 0.1069, Loss_G: 3.8590
End of Epoch 8 completed. Epoch [9/20], Batch [0/235], Loss_D: 0.0554, Loss_G: 3.5248 Epoch [9/20], Batch [50/235], Loss_D: 0.1436, Loss_G: 3.4984 Epoch [9/20], Batch [100/235], Loss_D: 0.0329, Loss_G: 6.2885 Epoch [9/20], Batch [150/235], Loss_D: 0.4614, Loss_G: 4.2605 Epoch [9/20], Batch [200/235], Loss_D: 0.4134, Loss_G: 3.6515
End of Epoch 9 completed. Epoch [10/20], Batch [0/235], Loss_D: 0.0390, Loss_G: 4.0490 Epoch [10/20], Batch [50/235], Loss_D: 0.0531, Loss_G: 3.9692 Epoch [10/20], Batch [100/235], Loss_D: 0.0499, Loss_G: 5.3644 Epoch [10/20], Batch [150/235], Loss_D: 0.4878, Loss_G: 3.4199 Epoch [10/20], Batch [200/235], Loss_D: 0.8591, Loss_G: 6.4835
End of Epoch 10 completed. Epoch [11/20], Batch [0/235], Loss_D: 0.4459, Loss_G: 2.0155 Epoch [11/20], Batch [50/235], Loss_D: 0.8302, Loss_G: 1.0756 Epoch [11/20], Batch [100/235], Loss_D: 0.1194, Loss_G: 3.5714 Epoch [11/20], Batch [150/235], Loss_D: 0.1289, Loss_G: 4.6166 Epoch [11/20], Batch [200/235], Loss_D: 0.0646, Loss_G: 4.2834
End of Epoch 11 completed. Epoch [12/20], Batch [0/235], Loss_D: 0.1541, Loss_G: 6.3816 Epoch [12/20], Batch [50/235], Loss_D: 0.1974, Loss_G: 5.4419 Epoch [12/20], Batch [100/235], Loss_D: 0.8444, Loss_G: 1.1366 Epoch [12/20], Batch [150/235], Loss_D: 0.1047, Loss_G: 3.8173 Epoch [12/20], Batch [200/235], Loss_D: 0.0680, Loss_G: 4.3182
End of Epoch 12 completed. Epoch [13/20], Batch [0/235], Loss_D: 1.0144, Loss_G: 4.6319 Epoch [13/20], Batch [50/235], Loss_D: 0.0882, Loss_G: 3.3174 Epoch [13/20], Batch [100/235], Loss_D: 0.0852, Loss_G: 4.0259 Epoch [13/20], Batch [150/235], Loss_D: 0.0449, Loss_G: 4.9833 Epoch [13/20], Batch [200/235], Loss_D: 0.0284, Loss_G: 6.1240
End of Epoch 13 completed. Epoch [14/20], Batch [0/235], Loss_D: 0.0258, Loss_G: 6.4724 Epoch [14/20], Batch [50/235], Loss_D: 0.0082, Loss_G: 5.8401 Epoch [14/20], Batch [100/235], Loss_D: 0.3763, Loss_G: 3.7295 Epoch [14/20], Batch [150/235], Loss_D: 0.1908, Loss_G: 2.3354 Epoch [14/20], Batch [200/235], Loss_D: 0.3295, Loss_G: 2.6553
End of Epoch 14 completed. Epoch [15/20], Batch [0/235], Loss_D: 0.1220, Loss_G: 3.8772 Epoch [15/20], Batch [50/235], Loss_D: 0.1942, Loss_G: 3.5176 Epoch [15/20], Batch [100/235], Loss_D: 0.1182, Loss_G: 3.3420 Epoch [15/20], Batch [150/235], Loss_D: 0.1327, Loss_G: 3.4660 Epoch [15/20], Batch [200/235], Loss_D: 1.4664, Loss_G: 10.9529
End of Epoch 15 completed. Epoch [16/20], Batch [0/235], Loss_D: 0.1515, Loss_G: 3.7957 Epoch [16/20], Batch [50/235], Loss_D: 0.2007, Loss_G: 4.0982 Epoch [16/20], Batch [100/235], Loss_D: 0.1016, Loss_G: 4.5341 Epoch [16/20], Batch [150/235], Loss_D: 0.0535, Loss_G: 5.0302 Epoch [16/20], Batch [200/235], Loss_D: 0.0535, Loss_G: 3.7612
End of Epoch 16 completed. Epoch [17/20], Batch [0/235], Loss_D: 0.0298, Loss_G: 4.4096 Epoch [17/20], Batch [50/235], Loss_D: 0.0466, Loss_G: 5.5681 Epoch [17/20], Batch [100/235], Loss_D: 0.9867, Loss_G: 1.6225 Epoch [17/20], Batch [150/235], Loss_D: 0.3657, Loss_G: 2.0744 Epoch [17/20], Batch [200/235], Loss_D: 0.2883, Loss_G: 3.9713
End of Epoch 17 completed. Epoch [18/20], Batch [0/235], Loss_D: 0.1419, Loss_G: 4.5721 Epoch [18/20], Batch [50/235], Loss_D: 0.1509, Loss_G: 4.5356 Epoch [18/20], Batch [100/235], Loss_D: 0.1508, Loss_G: 3.9809 Epoch [18/20], Batch [150/235], Loss_D: 0.0704, Loss_G: 3.7084 Epoch [18/20], Batch [200/235], Loss_D: 0.3868, Loss_G: 5.3765
End of Epoch 18 completed. Epoch [19/20], Batch [0/235], Loss_D: 0.1649, Loss_G: 4.5945 Epoch [19/20], Batch [50/235], Loss_D: 0.0568, Loss_G: 4.1682 Epoch [19/20], Batch [100/235], Loss_D: 0.0500, Loss_G: 4.2615 Epoch [19/20], Batch [150/235], Loss_D: 0.0479, Loss_G: 5.1583 Epoch [19/20], Batch [200/235], Loss_D: 0.0492, Loss_G: 5.5410
End of Epoch 19 completed.
WGAN-GP¶
In [23]:
# =======================================================
# Helper Module: View (for reshaping tensor in Generator)
# =======================================================
class View(nn.Module):
def __init__(self, shape):
super(View, self).__init__()
self.shape = shape
def forward(self, input):
return input.view(*self.shape)
In [24]:
# =======================================================
# Generator: DCGAN-Style for 64x64 Images
# =======================================================
class Generator(nn.Module):
def __init__(self, nz=100, ngf=64):
"""
Generator mapping a latent vector (dimension nz) to a 3x64x64 image.
It first projects the latent vector to a tensor of shape (ngf*8, 8, 8),
then upsamples through three layers:
8x8 -> 16x16,
16x16 -> 32x32,
32x32 -> 64x64.
The final convolution produces a 3-channel image.
"""
super(Generator, self).__init__()
self.nz = nz
self.main = nn.Sequential(
# Project latent vector and reshape:
nn.Linear(nz, ngf * 8 * 8 * 8),
nn.BatchNorm1d(ngf * 8 * 8 * 8),
nn.ReLU(True),
View((-1, ngf * 8, 8, 8)), # Shape: (ngf*8, 8, 8)
# Upsample: 8x8 -> 16x16
nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),
# Upsample: 16x16 -> 32x32
nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
# Upsample: 32x32 -> 64x64
nn.ConvTranspose2d(ngf * 2, 3, kernel_size=4, stride=2, padding=1, bias=False),
nn.Tanh() # Output values in the range [-1,1]
)
def forward(self, input):
return self.main(input)
In [25]:
# =======================================================
# Discriminator: DCGAN-Style for 64x64 Images
# =======================================================
class Discriminator(nn.Module):
def __init__(self, ndf=64):
"""
The discriminator takes a 3x64x64 image and outputs a scalar probability.
It uses four convolutional layers:
- 64x64 -> 32x32,
- 32x32 -> 16x16,
- 16x16 -> 8x8,
- 8x8 -> 4x4.
A final convolution with kernel size 4 collapses the 4x4 feature map to a single value.
"""
super(Discriminator, self).__init__()
self.main = nn.Sequential(
# Input: 3 x 64 x 64 -> ndf x 32 x 32
nn.Conv2d(3, ndf, kernel_size=4, stride=2, padding=1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# 32x32 -> 16x16; output channels: ndf*2
nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.2, inplace=True),
# 16x16 -> 8x8; output channels: ndf*4
nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True),
# 8x8 -> 4x4; output channels: ndf*8
nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, inplace=True),
# Final layer: 4x4 -> 1, using kernel size 4
nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),
# No sigmoid
)
def forward(self, input):
out = self.main(input)
return out.view(-1)
In [32]:
import torch
import torch.autograd as autograd
def compute_gradient_penalty(D, real_samples, fake_samples, device):
batch_size = real_samples.size(0)
# Generate a random epsilon in [0,1] for each sample. Shape: [B, 1, 1, 1]
epsilon = torch.rand(batch_size, 1, 1, 1, device=device)
# Create interpolated samples:
interpolates = epsilon * real_samples + (1 - epsilon) * fake_samples
interpolates.requires_grad_(True)
# Compute discriminator output on the interpolated samples:
d_interpolates = D(interpolates)
# For each sample, set the gradient output to 1.
grad_outputs = torch.ones(d_interpolates.size(), device=device)
# Compute gradients of the outputs with respect to the interpolated samples.
gradients = autograd.grad(
outputs=d_interpolates,
inputs=interpolates,
grad_outputs=grad_outputs,
create_graph=True,
retain_graph=True,
only_inputs=True
)[0]
# Flatten gradients to shape [batch_size, -1] and compute L2 norm for each sample.
gradients = gradients.view(batch_size, -1)
grad_norms = torch.sqrt(torch.sum(gradients ** 2, dim=1))
# Compute the gradient penalty as the mean squared deviation of the gradients' norm from 1.
gradient_penalty = torch.mean((grad_norms - 1) ** 2)
return gradient_penalty
In [33]:
# =======================================================
# Loss Functions for WGAN-GP
# =======================================================
def critic_loss(D, G, real_samples, latents, lambda_gp, device):
fake_samples = G(latents)
# Critic scores: higher for real samples.
real_scores = D(real_samples)
fake_scores = D(fake_samples)
loss = fake_scores.mean() - real_scores.mean()
gp = compute_gradient_penalty(D, real_samples, fake_samples, device)
loss += lambda_gp * gp
return loss
def generator_loss(D, G, latents):
# Generator loss: try to maximize the critic's score on fake images.
return - D(G(latents)).mean()
In [36]:
def train_wgan_gp(num_epochs=50, nz=100, n_critic=5, lambda_gp=10, device='cpu'):
device = torch.device(device)
G = Generator(nz=nz, ngf=64).to(device)
D = Discriminator(ndf=64).to(device)
# Parameters
LR = 1e-4 # Initial learning rate
MIN_LR = 1e-6 # Minimum learning rate
DECAY_FACTOR = 1.00004 # Decay factor per epoch
# Set up Adam optimizers with beta1=0.5 (as in your code)
optimizerD = optim.Adam(D.parameters(), lr=LR, betas=(0.5, 0.999))
optimizerG = optim.Adam(G.parameters(), lr=LR, betas=(0.5, 0.999))
# Define a lambda function for learning rate decay.
def lr_lambda(epoch):
# This returns the multiplicative factor that gets multiplied by the initial lr.
# It ensures that lr never goes below MIN_LR.
return max((1 / DECAY_FACTOR) ** epoch, MIN_LR / LR)
# Set up the learning rate schedulers for both optimizers.
schedulerD = optim.lr_scheduler.LambdaLR(optimizerD, lr_lambda=lr_lambda)
schedulerG = optim.lr_scheduler.LambdaLR(optimizerG, lr_lambda=lr_lambda)
fixed_noise = torch.randn(25, nz, device=device)
last_loss_D = None
last_loss_G = None
for epoch in range(num_epochs):
for i, (images, _) in enumerate(train_dataloader):
images = images.to(device)
batch_size = images.size(0)
# Update critic n_critic times.
for _ in range(n_critic):
noise = torch.randn(batch_size, nz, device=device)
optimizerD.zero_grad()
loss_D = critic_loss(D, G, images, noise, lambda_gp, device)
loss_D.backward()
optimizerD.step()
# Update generator once.
noise = torch.randn(batch_size, nz, device=device)
optimizerG.zero_grad()
loss_G = generator_loss(D, G, noise)
loss_G.backward()
optimizerG.step()
if i % 50 == 0:
print(f"Epoch [{epoch}/{num_epochs}], Batch [{i}/{len(train_dataloader)}], "
f"Critic Loss: {loss_D.item():.4f}, Generator Loss: {loss_G.item():.4f}")
last_loss_D = loss_D.item()
last_loss_G = loss_G.item()
# End of epoch: generate and display 25 images from fixed noise.
with torch.no_grad():
fake_samples = G(fixed_noise).detach().cpu()
grid = utils.make_grid(fake_samples, nrow=5, normalize=True)
plt.figure(figsize=(8, 8))
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.axis('off')
plt.title(f"Epoch {epoch} | Critic Loss: {last_loss_D:.4f} | Generator Loss: {last_loss_G:.4f}")
plt.show()
plt.close()
print(f"End of Epoch {epoch} completed.")
# At the end of the epoch:
schedulerD.step()
schedulerG.step()
return D, G
In [39]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
D, G = train_wgan_gp(num_epochs=25, nz=100, n_critic=5, lambda_gp=10, device=device)
Epoch [0/25], Batch [0/235], Critic Loss: -7.1815, Generator Loss: 3.3995 Epoch [0/25], Batch [50/235], Critic Loss: -319.1111, Generator Loss: 161.0143 Epoch [0/25], Batch [100/235], Critic Loss: -575.8983, Generator Loss: 293.8505 Epoch [0/25], Batch [150/235], Critic Loss: -947.6006, Generator Loss: 470.6101 Epoch [0/25], Batch [200/235], Critic Loss: -1310.2970, Generator Loss: 664.8742
End of Epoch 0 completed. Epoch [1/25], Batch [0/235], Critic Loss: -1603.1932, Generator Loss: 804.1373 Epoch [1/25], Batch [50/235], Critic Loss: -1037.7480, Generator Loss: 103.4770 Epoch [1/25], Batch [100/235], Critic Loss: -1613.4746, Generator Loss: 573.4979 Epoch [1/25], Batch [150/235], Critic Loss: -2640.9807, Generator Loss: 1264.7196 Epoch [1/25], Batch [200/235], Critic Loss: -3086.8433, Generator Loss: 1482.9570
End of Epoch 1 completed. Epoch [2/25], Batch [0/235], Critic Loss: 229.4240, Generator Loss: 1392.4490 Epoch [2/25], Batch [50/235], Critic Loss: -28.0986, Generator Loss: 1401.1266 Epoch [2/25], Batch [100/235], Critic Loss: -3754.0598, Generator Loss: 1808.0903 Epoch [2/25], Batch [150/235], Critic Loss: -4414.3828, Generator Loss: 2138.6362 Epoch [2/25], Batch [200/235], Critic Loss: 29.6658, Generator Loss: 1225.9301
End of Epoch 2 completed. Epoch [3/25], Batch [0/235], Critic Loss: 17.0579, Generator Loss: 1219.2595 Epoch [3/25], Batch [50/235], Critic Loss: 5.0561, Generator Loss: 1213.5356 Epoch [3/25], Batch [100/235], Critic Loss: -0.4950, Generator Loss: 1210.5740 Epoch [3/25], Batch [150/235], Critic Loss: -2.3364, Generator Loss: 1205.5253 Epoch [3/25], Batch [200/235], Critic Loss: -5.4682, Generator Loss: 1204.8875
End of Epoch 3 completed. Epoch [4/25], Batch [0/235], Critic Loss: -1.7772, Generator Loss: 1206.0137 Epoch [4/25], Batch [50/235], Critic Loss: -3.2213, Generator Loss: 1200.0872 Epoch [4/25], Batch [100/235], Critic Loss: -7.3166, Generator Loss: 1199.0438 Epoch [4/25], Batch [150/235], Critic Loss: -6.2481, Generator Loss: 1196.1101 Epoch [4/25], Batch [200/235], Critic Loss: -6.9179, Generator Loss: 1193.1678
End of Epoch 4 completed. Epoch [5/25], Batch [0/235], Critic Loss: -5.9082, Generator Loss: 1195.7135 Epoch [5/25], Batch [50/235], Critic Loss: -7.8350, Generator Loss: 1186.9421 Epoch [5/25], Batch [100/235], Critic Loss: -9.1983, Generator Loss: 1174.1233 Epoch [5/25], Batch [150/235], Critic Loss: -8.1507, Generator Loss: 1183.1907 Epoch [5/25], Batch [200/235], Critic Loss: -7.0861, Generator Loss: 1164.1135
End of Epoch 5 completed. Epoch [6/25], Batch [0/235], Critic Loss: -6.0164, Generator Loss: 1167.7697 Epoch [6/25], Batch [50/235], Critic Loss: -8.0003, Generator Loss: 1160.3273 Epoch [6/25], Batch [100/235], Critic Loss: -8.2061, Generator Loss: 1158.1433 Epoch [6/25], Batch [150/235], Critic Loss: -8.6398, Generator Loss: 1143.9407 Epoch [6/25], Batch [200/235], Critic Loss: -12.2942, Generator Loss: 1142.6199
End of Epoch 6 completed. Epoch [7/25], Batch [0/235], Critic Loss: -9.0067, Generator Loss: 1134.8313 Epoch [7/25], Batch [50/235], Critic Loss: -11.3545, Generator Loss: 1120.7852 Epoch [7/25], Batch [100/235], Critic Loss: -10.2414, Generator Loss: 1113.6353 Epoch [7/25], Batch [150/235], Critic Loss: -13.3492, Generator Loss: 1095.5884 Epoch [7/25], Batch [200/235], Critic Loss: -14.7403, Generator Loss: 1083.0547
End of Epoch 7 completed. Epoch [8/25], Batch [0/235], Critic Loss: -12.3538, Generator Loss: 1081.8322 Epoch [8/25], Batch [50/235], Critic Loss: -11.8157, Generator Loss: 1062.2549 Epoch [8/25], Batch [100/235], Critic Loss: -13.7031, Generator Loss: 1061.2454 Epoch [8/25], Batch [150/235], Critic Loss: -15.6746, Generator Loss: 1038.0032 Epoch [8/25], Batch [200/235], Critic Loss: -12.2785, Generator Loss: 1036.6870
End of Epoch 8 completed. Epoch [9/25], Batch [0/235], Critic Loss: -12.3723, Generator Loss: 1021.6680 Epoch [9/25], Batch [50/235], Critic Loss: -9.6453, Generator Loss: 1016.1304 Epoch [9/25], Batch [100/235], Critic Loss: -13.5530, Generator Loss: 1014.8564 Epoch [9/25], Batch [150/235], Critic Loss: -15.0983, Generator Loss: 991.1080 Epoch [9/25], Batch [200/235], Critic Loss: -13.8649, Generator Loss: 984.4841
End of Epoch 9 completed. Epoch [10/25], Batch [0/235], Critic Loss: -9.6877, Generator Loss: 986.0001 Epoch [10/25], Batch [50/235], Critic Loss: -11.9478, Generator Loss: 973.5349 Epoch [10/25], Batch [100/235], Critic Loss: -12.6790, Generator Loss: 959.2227 Epoch [10/25], Batch [150/235], Critic Loss: -14.4752, Generator Loss: 952.8516 Epoch [10/25], Batch [200/235], Critic Loss: -12.7882, Generator Loss: 948.1614
End of Epoch 10 completed. Epoch [11/25], Batch [0/235], Critic Loss: -10.8146, Generator Loss: 943.7811 Epoch [11/25], Batch [50/235], Critic Loss: -14.7517, Generator Loss: 940.8361 Epoch [11/25], Batch [100/235], Critic Loss: -16.9774, Generator Loss: 943.5679 Epoch [11/25], Batch [150/235], Critic Loss: -16.0654, Generator Loss: 939.0920 Epoch [11/25], Batch [200/235], Critic Loss: -20.3614, Generator Loss: 936.4711
End of Epoch 11 completed. Epoch [12/25], Batch [0/235], Critic Loss: -18.1832, Generator Loss: 928.4172 Epoch [12/25], Batch [50/235], Critic Loss: -15.1205, Generator Loss: 928.4465 Epoch [12/25], Batch [100/235], Critic Loss: -15.1267, Generator Loss: 939.8590 Epoch [12/25], Batch [150/235], Critic Loss: -17.2297, Generator Loss: 934.1822 Epoch [12/25], Batch [200/235], Critic Loss: -19.3050, Generator Loss: 939.5249
End of Epoch 12 completed. Epoch [13/25], Batch [0/235], Critic Loss: -12.7700, Generator Loss: 928.9781 Epoch [13/25], Batch [50/235], Critic Loss: -15.1378, Generator Loss: 934.3294 Epoch [13/25], Batch [100/235], Critic Loss: -14.2550, Generator Loss: 935.8308 Epoch [13/25], Batch [150/235], Critic Loss: -21.9916, Generator Loss: 936.1085 Epoch [13/25], Batch [200/235], Critic Loss: -17.9902, Generator Loss: 944.5997
End of Epoch 13 completed. Epoch [14/25], Batch [0/235], Critic Loss: -14.3817, Generator Loss: 929.7521 Epoch [14/25], Batch [50/235], Critic Loss: -18.0552, Generator Loss: 944.4028 Epoch [14/25], Batch [100/235], Critic Loss: -11.2363, Generator Loss: 935.4798 Epoch [14/25], Batch [150/235], Critic Loss: -20.1454, Generator Loss: 942.7131 Epoch [14/25], Batch [200/235], Critic Loss: -22.4876, Generator Loss: 942.3488
End of Epoch 14 completed. Epoch [15/25], Batch [0/235], Critic Loss: -16.9694, Generator Loss: 936.5730 Epoch [15/25], Batch [50/235], Critic Loss: -23.6465, Generator Loss: 943.0557 Epoch [15/25], Batch [100/235], Critic Loss: -13.8048, Generator Loss: 942.3694 Epoch [15/25], Batch [150/235], Critic Loss: -17.4215, Generator Loss: 935.6983 Epoch [15/25], Batch [200/235], Critic Loss: -18.4353, Generator Loss: 937.7614
End of Epoch 15 completed. Epoch [16/25], Batch [0/235], Critic Loss: -8.8008, Generator Loss: 936.5846 Epoch [16/25], Batch [50/235], Critic Loss: -21.4678, Generator Loss: 955.0105 Epoch [16/25], Batch [100/235], Critic Loss: -12.5794, Generator Loss: 943.1880 Epoch [16/25], Batch [150/235], Critic Loss: -23.7656, Generator Loss: 939.2748 Epoch [16/25], Batch [200/235], Critic Loss: -26.7899, Generator Loss: 936.5636
End of Epoch 16 completed. Epoch [17/25], Batch [0/235], Critic Loss: -11.8148, Generator Loss: 929.4254 Epoch [17/25], Batch [50/235], Critic Loss: -18.5867, Generator Loss: 941.4560 Epoch [17/25], Batch [100/235], Critic Loss: -11.9151, Generator Loss: 938.3667 Epoch [17/25], Batch [150/235], Critic Loss: -17.8414, Generator Loss: 943.4852 Epoch [17/25], Batch [200/235], Critic Loss: -14.1036, Generator Loss: 949.7693
End of Epoch 17 completed. Epoch [18/25], Batch [0/235], Critic Loss: -28.9961, Generator Loss: 940.0487 Epoch [18/25], Batch [50/235], Critic Loss: -21.4629, Generator Loss: 939.6287 Epoch [18/25], Batch [100/235], Critic Loss: -21.2551, Generator Loss: 948.8262 Epoch [18/25], Batch [150/235], Critic Loss: -14.8244, Generator Loss: 939.2914 Epoch [18/25], Batch [200/235], Critic Loss: -16.4426, Generator Loss: 947.8840
End of Epoch 18 completed. Epoch [19/25], Batch [0/235], Critic Loss: -18.4954, Generator Loss: 946.0103 Epoch [19/25], Batch [50/235], Critic Loss: -13.0874, Generator Loss: 941.6476 Epoch [19/25], Batch [100/235], Critic Loss: -8.4792, Generator Loss: 936.5409 Epoch [19/25], Batch [150/235], Critic Loss: -8.4111, Generator Loss: 948.5744 Epoch [19/25], Batch [200/235], Critic Loss: -8.8216, Generator Loss: 949.1342
End of Epoch 19 completed. Epoch [20/25], Batch [0/235], Critic Loss: -9.7134, Generator Loss: 945.5421 Epoch [20/25], Batch [50/235], Critic Loss: -6.9775, Generator Loss: 948.1437 Epoch [20/25], Batch [100/235], Critic Loss: -16.3411, Generator Loss: 947.7766 Epoch [20/25], Batch [150/235], Critic Loss: -29.5035, Generator Loss: 953.3276 Epoch [20/25], Batch [200/235], Critic Loss: -12.2876, Generator Loss: 940.1945
End of Epoch 20 completed. Epoch [21/25], Batch [0/235], Critic Loss: -14.1305, Generator Loss: 933.7478 Epoch [21/25], Batch [50/235], Critic Loss: -17.7026, Generator Loss: 942.4226 Epoch [21/25], Batch [100/235], Critic Loss: -18.2549, Generator Loss: 942.1077 Epoch [21/25], Batch [150/235], Critic Loss: -20.6455, Generator Loss: 939.1882 Epoch [21/25], Batch [200/235], Critic Loss: -12.8188, Generator Loss: 952.3013
End of Epoch 21 completed. Epoch [22/25], Batch [0/235], Critic Loss: -12.8006, Generator Loss: 944.2827 Epoch [22/25], Batch [50/235], Critic Loss: -15.2925, Generator Loss: 948.0229 Epoch [22/25], Batch [100/235], Critic Loss: -18.3503, Generator Loss: 947.7360 Epoch [22/25], Batch [150/235], Critic Loss: -13.2577, Generator Loss: 941.4637 Epoch [22/25], Batch [200/235], Critic Loss: -12.0882, Generator Loss: 953.0192
End of Epoch 22 completed. Epoch [23/25], Batch [0/235], Critic Loss: -1.3852, Generator Loss: 937.0220 Epoch [23/25], Batch [50/235], Critic Loss: -8.8576, Generator Loss: 954.6143 Epoch [23/25], Batch [100/235], Critic Loss: -13.4211, Generator Loss: 946.9574 Epoch [23/25], Batch [150/235], Critic Loss: -12.9877, Generator Loss: 948.9841 Epoch [23/25], Batch [200/235], Critic Loss: -17.1787, Generator Loss: 944.7324
End of Epoch 23 completed. Epoch [24/25], Batch [0/235], Critic Loss: -13.0251, Generator Loss: 942.5323 Epoch [24/25], Batch [50/235], Critic Loss: -28.6021, Generator Loss: 955.1877 Epoch [24/25], Batch [100/235], Critic Loss: -19.4927, Generator Loss: 946.7385
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) Cell In[39], line 2 1 device = 'cuda' if torch.cuda.is_available() else 'cpu' ----> 2 D, G = train_wgan_gp(num_epochs=25, nz=100, n_critic=5, lambda_gp=10, device=device) Cell In[36], line 41, in train_wgan_gp(num_epochs, nz, n_critic, lambda_gp, device) 39 noise = torch.randn(batch_size, nz, device=device) 40 optimizerD.zero_grad() ---> 41 loss_D = critic_loss(D, G, images, noise, lambda_gp, device) 42 loss_D.backward() 43 optimizerD.step() Cell In[33], line 10, in critic_loss(D, G, real_samples, latents, lambda_gp, device) 8 fake_scores = D(fake_samples) 9 loss = fake_scores.mean() - real_scores.mean() ---> 10 gp = compute_gradient_penalty(D, real_samples, fake_samples, device) 11 loss += lambda_gp * gp 12 return loss Cell In[32], line 14, in compute_gradient_penalty(D, real_samples, fake_samples, device) 11 interpolates.requires_grad_(True) 13 # Compute discriminator output on the interpolated samples: ---> 14 d_interpolates = D(interpolates) 16 # For each sample, set the gradient output to 1. 17 grad_outputs = torch.ones(d_interpolates.size(), device=device) File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs) 1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1735 else: -> 1736 return self._call_impl(*args, **kwargs) File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs) 1742 # If we don't have any hooks, we want to skip the rest of the logic in 1743 # this function, and just call forward. 1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1745 or _global_backward_pre_hooks or _global_backward_hooks 1746 or _global_forward_hooks or _global_forward_pre_hooks): -> 1747 return forward_call(*args, **kwargs) 1749 result = None 1750 called_always_called_hooks = set() Cell In[25], line 38, in Discriminator.forward(self, input) 37 def forward(self, input): ---> 38 out = self.main(input) 39 return out.view(-1) File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs) 1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1735 else: -> 1736 return self._call_impl(*args, **kwargs) File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs) 1742 # If we don't have any hooks, we want to skip the rest of the logic in 1743 # this function, and just call forward. 1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1745 or _global_backward_pre_hooks or _global_backward_hooks 1746 or _global_forward_hooks or _global_forward_pre_hooks): -> 1747 return forward_call(*args, **kwargs) 1749 result = None 1750 called_always_called_hooks = set() File ~/.local/lib/python3.10/site-packages/torch/nn/modules/container.py:250, in Sequential.forward(self, input) 248 def forward(self, input): 249 for module in self: --> 250 input = module(input) 251 return input File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs) 1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1735 else: -> 1736 return self._call_impl(*args, **kwargs) File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs) 1742 # If we don't have any hooks, we want to skip the rest of the logic in 1743 # this function, and just call forward. 1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1745 or _global_backward_pre_hooks or _global_backward_hooks 1746 or _global_forward_hooks or _global_forward_pre_hooks): -> 1747 return forward_call(*args, **kwargs) 1749 result = None 1750 called_always_called_hooks = set() File ~/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:554, in Conv2d.forward(self, input) 553 def forward(self, input: Tensor) -> Tensor: --> 554 return self._conv_forward(input, self.weight, self.bias) File ~/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:549, in Conv2d._conv_forward(self, input, weight, bias) 537 if self.padding_mode != "zeros": 538 return F.conv2d( 539 F.pad( 540 input, self._reversed_padding_repeated_twice, mode=self.padding_mode (...) 547 self.groups, 548 ) --> 549 return F.conv2d( 550 input, weight, bias, self.stride, self.padding, self.dilation, self.groups 551 ) KeyboardInterrupt: